week 6: multilevel models

multilevel tadpoles

multilevel models

We’re starting our unit on multilevel models, which can be thought of as models that “remember” features of clusters of data as they learn about all the clusters. The model will pool information across clusters (e.g., our estimates about cluster A will be informed in part by clusters B, C, and D). This tends to improve estimates about each cluster. Here are some other benefits of multilevel modeling:

  1. improved estimates for repeated sampling. When the same clusters are sampled repeatedly, a single-level model will tend to over- or under-fit the data.
  2. improved estimates for imbalance in sampling. prevent over-sampled clusters from dominating inference, while also balancing the fact that larger clusters have more information.
  3. estimates of variation. model variation explicitly!
  4. avoid averaging, retain variation. averaging manufactures false confidence (artificially inflates precision) and introduces arbitrary data transformations.

Multilevel modeling should be your default approach.

example: tadpoles

data(reedfrogs, package = "rethinking")
d <- reedfrogs
str(d)
'data.frame':   48 obs. of  5 variables:
 $ density : int  10 10 10 10 10 10 10 10 10 10 ...
 $ pred    : Factor w/ 2 levels "no","pred": 1 1 1 1 1 1 1 1 2 2 ...
 $ size    : Factor w/ 2 levels "big","small": 1 1 1 1 2 2 2 2 1 1 ...
 $ surv    : int  9 10 7 10 9 9 10 9 4 9 ...
 $ propsurv: num  0.9 1 0.7 1 0.9 0.9 1 0.9 0.4 0.9 ...

Each row is a tank. These tanks are examples of clusters. We’ll need to create an index variable for each tank.

d$tank <- 1:nrow(d)
d %>% count(density)
  density  n
1      10 16
2      25 16
3      35 16

How would you fit a single-level model in which you estimate the survival rate (surv) of tadpoles separately for each tank?

  • What is the distribution of the outcome variable?
  • What is the formula?
  • What is/are your prior(s)?

\[\begin{align*} S_i &\sim \text{Binomial}(N_i, p_i) \\ \text{logit}(p_i) &= \alpha_{\text{tank}[i]} \\ \alpha_j &\sim \text{Normal}(0, 1.5) && \text{for } j = 1, \dots, 48 \end{align*}\]

m1 <- 
  brm(data = d, 
      family = binomial,
      surv | trials(density) ~ 0 + factor(tank),
      prior = prior(normal(0, 1.5), class = b),
      iter = 2000, warmup = 1000, chains = 4, cores = 4,
      seed = 13,
      file = here("files/data/generated_data/m62.1"))
m1
 Family: binomial 
  Links: mu = logit 
Formula: surv | trials(density) ~ 0 + factor(tank) 
   Data: d (Number of observations: 48) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Regression Coefficients:
             Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
factortank1      1.72      0.75     0.36     3.28 1.00     6727     2922
factortank2      2.39      0.87     0.85     4.24 1.00     5009     2249
factortank3      0.75      0.63    -0.45     2.06 1.00     5975     2721
factortank4      2.41      0.90     0.84     4.40 1.00     5413     2712
factortank5      1.72      0.78     0.34     3.36 1.00     6346     2332
factortank6      1.73      0.75     0.36     3.39 1.00     5673     2882
factortank7      2.40      0.87     0.87     4.24 1.00     6013     2773
factortank8      1.71      0.76     0.32     3.29 1.00     5652     2920
factortank9     -0.37      0.61    -1.60     0.80 1.00     5643     2887
factortank10     1.70      0.75     0.39     3.31 1.00     6262     2260
factortank11     0.74      0.63    -0.46     2.05 1.00     6269     2696
factortank12     0.39      0.63    -0.82     1.64 1.00     5644     2858
factortank13     0.76      0.66    -0.46     2.12 1.00     5430     2848
factortank14     0.01      0.61    -1.16     1.22 1.00     6416     3038
factortank15     1.72      0.76     0.34     3.34 1.00     5830     2712
factortank16     1.72      0.78     0.36     3.42 1.00     5907     2726
factortank17     2.54      0.67     1.36     4.01 1.00     4795     2497
factortank18     2.14      0.61     1.05     3.44 1.00     6054     2619
factortank19     1.80      0.54     0.83     2.92 1.00     6042     3107
factortank20     3.09      0.79     1.72     4.82 1.00     5402     2768
factortank21     2.15      0.62     1.05     3.49 1.00     5870     2806
factortank22     2.14      0.57     1.12     3.36 1.00     5687     2898
factortank23     2.13      0.59     1.10     3.41 1.00     5338     3008
factortank24     1.55      0.51     0.61     2.60 1.00     6469     3086
factortank25    -1.11      0.45    -2.04    -0.26 1.00     5811     2806
factortank26     0.08      0.38    -0.65     0.81 1.00     5386     3054
factortank27    -1.54      0.49    -2.54    -0.63 1.00     5931     2981
factortank28    -0.55      0.40    -1.36     0.20 1.00     5774     2862
factortank29     0.07      0.40    -0.72     0.84 1.00     5880     2852
factortank30     1.32      0.48     0.44     2.30 1.00     5203     2097
factortank31    -0.72      0.42    -1.55     0.09 1.00     7194     3016
factortank32    -0.39      0.42    -1.23     0.40 1.00     5849     2924
factortank33     2.85      0.67     1.71     4.33 1.00     5415     2328
factortank34     2.47      0.59     1.42     3.76 1.00     5835     2769
factortank35     2.46      0.57     1.46     3.68 1.00     5369     2660
factortank36     1.91      0.49     1.02     2.97 1.00     6166     2748
factortank37     1.91      0.49     1.00     2.93 1.00     6123     2860
factortank38     3.37      0.77     2.05     5.04 1.00     5313     2355
factortank39     2.46      0.58     1.43     3.72 1.00     6008     2835
factortank40     2.16      0.53     1.21     3.32 1.00     5945     2463
factortank41    -1.91      0.49    -2.95    -1.02 1.00     6293     2323
factortank42    -0.63      0.35    -1.32     0.04 1.00     6646     3016
factortank43    -0.51      0.34    -1.19     0.16 1.00     5326     3035
factortank44    -0.39      0.33    -1.05     0.24 1.00     6226     3190
factortank45     0.52      0.35    -0.15     1.22 1.00     7301     2643
factortank46    -0.63      0.35    -1.34     0.04 1.00     5798     2948
factortank47     1.91      0.49     1.03     2.91 1.00     5941     2893
factortank48    -0.07      0.34    -0.74     0.59 1.00     7503     2958

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

Now we’ll consider the multilevel alternative.

\[\begin{align*} S_i &\sim \text{Binomial}(N_i, p_i) \\ \text{logit}(p_i) &= \alpha_{\text{tank}[i]} \\ \alpha_j &\sim \text{Normal}(\bar{\alpha}, \sigma) \\ \bar{\alpha} &\sim \text{Normal}(0, 1.5) \\ \sigma &\sim \text{Exponential}(1) \end{align*}\]

m2 <- 
  brm(data = d, 
      family = binomial,
      surv | trials(density) ~ 1 + (1 | tank),
      prior = c( prior(normal(0, 1.5), class = Intercept), # alpha bar
                 prior(exponential(1), class = sd)),       # sigma
      iter = 2000, warmup = 1000, chains = 4, cores = 4,
      seed = 13,
      file = here("files/data/generated_data/m62.2"))

The syntax for the varying effects follows the lme4 style, ( <varying parameter(s)> | <grouping variable(s)> ). In this case (1 | tank) indicates only the intercept, 1, varies by tank. The extent to which parameters vary is controlled by the prior, prior(exponential(1), class = sd), which is parameterized in the standard deviation metric.
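If you're ever unsure which prior classes a brms formula exposes, you can ask before fitting. A quick sketch using get_prior() (the default priors it reports will vary by brms version):

```r
library(brms)

# List every parameter class (Intercept, sd, ...) this formula creates,
# along with brms's default priors, before fitting anything.
get_prior(surv | trials(density) ~ 1 + (1 | tank),
          data = d, family = binomial)
```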

Compare our two models.

m1 <- add_criterion(m1, "waic")
m2 <- add_criterion(m2, "waic")

w <- loo_compare(m1, m2, criterion = "waic")

print(w, simplify = F)
   elpd_diff se_diff elpd_waic se_elpd_waic p_waic se_p_waic waic   se_waic
m2    0.0       0.0   -99.7       3.7         20.6    0.8     199.4    7.3 
m1   -7.3       1.8  -107.0       2.3         25.3    1.2     214.0    4.7 

exercise

  1. Fit the multilevel tadpoles model.

  2. In the video lecture, McElreath demonstrates that cross-validation can be used to determine that an appropriate prior for sigma is about 1.8. Recreate this demonstration. (Start with just 3 values of \(\sigma\) so you can finish this in class. You can update later with more values for fun.)

solution

sigma_seq <- exp(seq(from=log(0.1), to=log(5), len=20))
waic_values <- numeric(length(sigma_seq))
loo_values <- numeric(length(sigma_seq))

for(i in 1:length(sigma_seq)) {
  sigma_value <- sigma_seq[i]
  cat("Fitting model", i, "with prior rate =", sigma_value, "\n")
  
  # Create the prior with the specific numeric value
  my_prior <- c(
    set_prior("normal(0, 1.5)", class = "Intercept"),
    set_prior(paste0("exponential(", sigma_value, ")"), class = "sd")
  )
  
  # Fit the model with this prior
  model <- brm(
    data = d, 
    family = binomial,
    surv | trials(density) ~ 1 + (1 | tank),
    prior = my_prior,
    iter = 1000, warmup = 500, chains = 4, cores = 4,
    seed = 13,
    file = here(paste0("files/data/generated_data/m62_prior_", i))
  )
  
  # Calculate WAIC
  model <- add_criterion(model, "waic")
  waic_values[i] <- model$criteria$waic$estimates["waic", "Estimate"]
  # Calculate LOO
  model <- add_criterion(model, "loo")
  loo_values[i] <- model$criteria$loo$estimates["elpd_loo", "Estimate"]
  
  
  # Optional: clean up to save memory
  rm(model)
  gc()
}
Fitting model 1 with prior rate = 0.1 
Fitting model 2 with prior rate = 0.1228625 
Fitting model 3 with prior rate = 0.150952 
Fitting model 4 with prior rate = 0.1854635 
Fitting model 5 with prior rate = 0.2278651 
Fitting model 6 with prior rate = 0.2799609 
Fitting model 7 with prior rate = 0.3439671 
Fitting model 8 with prior rate = 0.4226066 
Fitting model 9 with prior rate = 0.5192252 
Fitting model 10 with prior rate = 0.6379333 
Fitting model 11 with prior rate = 0.783781 
Fitting model 12 with prior rate = 0.9629732 
Fitting model 13 with prior rate = 1.183133 
Fitting model 14 with prior rate = 1.453628 
Fitting model 15 with prior rate = 1.785964 
Fitting model 16 with prior rate = 2.19428 
Fitting model 17 with prior rate = 2.695948 
Fitting model 18 with prior rate = 3.312311 
Fitting model 19 with prior rate = 4.069589 
Fitting model 20 with prior rate = 5 

solution

data.frame(
  sigma = sigma_seq,
  loo = loo_values
) %>% 
  ggplot( aes( x=sigma, y=loo )) + 
  geom_point( size=2, color = "#1c5253" ) +
  geom_line(alpha = .5) 

exercise

Check your trace plots to evaluate whether your chains have mixed appropriately.

solution

plot(m2)

Here’s an additional package – bayesplot – to help you visualize chains.

library(bayesplot)

as_draws_df(m2) %>% 
  select(contains("tank"), .chain) %>% 
  select(1:8, chain=.chain) %>% 
  mcmc_trace(facet_args = list(ncol=4))

bayesplot also makes trace rank (“trank”) plots, via mcmc_rank_overlay(), which are often easier to read than raw traces.

as_draws_df(m2) %>% 
  select(contains("tank"), .chain) %>% 
  select(9:16, chain = .chain) %>% 
  mcmc_rank_overlay(facet_args = list(ncol=4, scales="free"))
m2
 Family: binomial 
  Links: mu = logit 
Formula: surv | trials(density) ~ 1 + (1 | tank) 
   Data: d (Number of observations: 48) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Multilevel Hyperparameters:
~tank (Number of levels: 48) 
              Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)     1.63      0.21     1.27     2.09 1.01     1094     1936

Regression Coefficients:
          Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept     1.37      0.26     0.88     1.88 1.00      670     1109

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

Our summary output doesn’t show us the individual intercepts for each tank. No worries – let’s draw from our posterior!

post <- as_draws_df(m2)

str(post)
draws_df [4,000 × 56] (S3: draws_df/draws/tbl_df/tbl/data.frame)
 $ b_Intercept         : num [1:4000] 1.27 1.02 1.28 1.44 1.33 ...
 $ sd_tank__Intercept  : num [1:4000] 1.7 1.55 1.75 1.7 1.47 ...
 $ Intercept           : num [1:4000] 1.27 1.02 1.28 1.44 1.33 ...
 $ r_tank[1,Intercept] : num [1:4000] 1.211 2.58 0.217 1.898 1.411 ...
 $ r_tank[2,Intercept] : num [1:4000] 1.64 1.76 1.19 2.31 3.99 ...
 $ r_tank[3,Intercept] : num [1:4000] -0.55592 -0.18093 0.00559 -0.61033 -0.08098 ...
 $ r_tank[4,Intercept] : num [1:4000] 2.729 0.896 2.576 0.793 2.261 ...
 $ r_tank[5,Intercept] : num [1:4000] 1.053 1.075 0.858 0.59 -0.347 ...
 $ r_tank[6,Intercept] : num [1:4000] -0.4 0.108 1.388 0.124 1.565 ...
 $ r_tank[7,Intercept] : num [1:4000] 2.302 2.346 0.905 2.419 1.822 ...
 $ r_tank[8,Intercept] : num [1:4000] 1.377 1.097 0.416 1.486 -1.092 ...
 $ r_tank[9,Intercept] : num [1:4000] -1.326 0.306 -3.024 -0.159 -1.421 ...
 $ r_tank[10,Intercept]: num [1:4000] 1.202 2.468 0.906 0.848 0.471 ...
 $ r_tank[11,Intercept]: num [1:4000] -1.31 0.611 -1.164 0.585 0.397 ...
 $ r_tank[12,Intercept]: num [1:4000] -1.715 -0.502 -0.11 -1.239 -0.412 ...
 $ r_tank[13,Intercept]: num [1:4000] -0.31 0.67 -0.889 0.205 -1.14 ...
 $ r_tank[14,Intercept]: num [1:4000] -1.31 -1.01 -1.55 -0.88 -1.19 ...
 $ r_tank[15,Intercept]: num [1:4000] 0.57 1.891 1.159 0.32 0.239 ...
 $ r_tank[16,Intercept]: num [1:4000] 1.511 1.037 1.401 0.049 1.184 ...
 $ r_tank[17,Intercept]: num [1:4000] 3.063 2.876 0.834 2.425 1.553 ...
 $ r_tank[18,Intercept]: num [1:4000] 1.581 2.36 1.379 0.662 0.531 ...
 $ r_tank[19,Intercept]: num [1:4000] 0.0844 1.1413 1.5953 0.4047 0.5816 ...
 $ r_tank[20,Intercept]: num [1:4000] 1.49 1.5 2.99 2.11 1.61 ...
 $ r_tank[21,Intercept]: num [1:4000] -0.2576 1.6767 0.7497 2.4893 -0.0909 ...
 $ r_tank[22,Intercept]: num [1:4000] 1.6086 0.0385 2.624 -0.1048 1.2849 ...
 $ r_tank[23,Intercept]: num [1:4000] 1.011 2.16 0.287 1.959 2.075 ...
 $ r_tank[24,Intercept]: num [1:4000] 0.8702 0.727 0.0215 0.8118 -0.2674 ...
 $ r_tank[25,Intercept]: num [1:4000] -2.47 -1.91 -2.38 -2.18 -2.28 ...
 $ r_tank[26,Intercept]: num [1:4000] -0.713 -0.915 -1.589 -0.713 -1.806 ...
 $ r_tank[27,Intercept]: num [1:4000] -3.06 -2.61 -2.43 -3.04 -2 ...
 $ r_tank[28,Intercept]: num [1:4000] -2.33 -1.22 -1.5 -1.94 -2.16 ...
 $ r_tank[29,Intercept]: num [1:4000] -0.803 -1.128 -0.677 -1.192 -1.635 ...
 $ r_tank[30,Intercept]: num [1:4000] -0.041 0.569 0.599 -0.531 0.126 ...
 $ r_tank[31,Intercept]: num [1:4000] -2.41 -1.49 -1.7 -2.16 -1.38 ...
 $ r_tank[32,Intercept]: num [1:4000] -1.69 -1.77 -1.69 -1.67 -1.95 ...
 $ r_tank[33,Intercept]: num [1:4000] 1.33 1.16 2.78 1.26 2.13 ...
 $ r_tank[34,Intercept]: num [1:4000] 1.95 1.7 1.68 1.01 1.11 ...
 $ r_tank[35,Intercept]: num [1:4000] 3.12 2.16 1.04 2.38 1.2 ...
 $ r_tank[36,Intercept]: num [1:4000] 0.6551 -0.1295 2.1778 -0.0864 1.629 ...
 $ r_tank[37,Intercept]: num [1:4000] 0.844 1.263 0.975 0.661 0.41 ...
 $ r_tank[38,Intercept]: num [1:4000] 1.33 1.76 4.25 1.25 2.73 ...
 $ r_tank[39,Intercept]: num [1:4000] -0.0802 2.2216 0.7636 2.1258 1.6694 ...
 $ r_tank[40,Intercept]: num [1:4000] 0.298 1.411 0.884 2.038 0.373 ...
 $ r_tank[41,Intercept]: num [1:4000] -3.11 -3.57 -3.67 -3.44 -2.82 ...
 $ r_tank[42,Intercept]: num [1:4000] -1.78 -1.73 -2.16 -1.65 -2.43 ...
 $ r_tank[43,Intercept]: num [1:4000] -1.32 -1.54 -2.07 -1.82 -1.75 ...
 $ r_tank[44,Intercept]: num [1:4000] -1.86 -1.54 -1.62 -1.91 -1.22 ...
 $ r_tank[45,Intercept]: num [1:4000] -1.086 -0.611 -0.897 -0.937 -0.698 ...
 $ r_tank[46,Intercept]: num [1:4000] -1.6 -1.98 -1.69 -1.77 -2.15 ...
 $ r_tank[47,Intercept]: num [1:4000] 1.119 0.726 1.031 0.843 0.861 ...
 $ r_tank[48,Intercept]: num [1:4000] -1.55 -1.311 -0.288 -2.261 -0.636 ...
 $ lprior              : num [1:4000] -3.39 -3.1 -3.44 -3.49 -3.18 ...
 $ lp__                : num [1:4000] -156 -160 -156 -156 -162 ...
 $ .chain              : int [1:4000] 1 1 1 1 1 1 1 1 1 1 ...
 $ .iteration          : int [1:4000] 1 2 3 4 5 6 7 8 9 10 ...
 $ .draw               : int [1:4000] 1 2 3 4 5 6 7 8 9 10 ...

Each of the tank-specific intercepts here represents how that tank’s intercept deviates from the grand mean (b_Intercept). We also have to remember that these values are on the logit scale. If we want to see them as probabilities, we have to back-transform.
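As a quick reminder of what the back-transform does: inv_logit_scaled() is just the logistic function, 1 / (1 + exp(-x)), which maps logits on the whole real line into probabilities in (0, 1). For example:

```r
library(brms)

inv_logit_scaled(0)     # 0.5: a logit of 0 is a coin flip
inv_logit_scaled(1.37)  # ~0.80: the grand-mean intercept as a survival probability
inv_logit_scaled(-2)    # ~0.12
```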

modeled_prop = post %>% 
  select(b_Intercept, starts_with("r_tank")) %>% 
  pivot_longer(-b_Intercept, values_to = "logit") %>%
  mutate(logit = logit + b_Intercept,
    prob = inv_logit_scaled(logit)) %>% 
  group_by(name) 

modeled_prop %>% 
  median_qi(prob)
# A tibble: 48 × 7
   name                  prob .lower .upper .width .point .interval
   <chr>                <dbl>  <dbl>  <dbl>  <dbl> <chr>  <chr>    
 1 r_tank[1,Intercept]  0.890  0.655  0.982   0.95 median qi       
 2 r_tank[10,Intercept] 0.888  0.659  0.981   0.95 median qi       
 3 r_tank[11,Intercept] 0.726  0.432  0.911   0.95 median qi       
 4 r_tank[12,Intercept] 0.639  0.349  0.866   0.95 median qi       
 5 r_tank[13,Intercept] 0.727  0.445  0.917   0.95 median qi       
 6 r_tank[14,Intercept] 0.549  0.264  0.809   0.95 median qi       
 7 r_tank[15,Intercept] 0.889  0.645  0.984   0.95 median qi       
 8 r_tank[16,Intercept] 0.889  0.652  0.981   0.95 median qi       
 9 r_tank[17,Intercept] 0.945  0.827  0.991   0.95 median qi       
10 r_tank[18,Intercept] 0.912  0.772  0.979   0.95 median qi       
# ℹ 38 more rows
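As a cross-check on the by-hand calculation above, brms will combine the grand mean and the tank offsets for you: coef() returns posterior summaries of b_Intercept + r_tank for each tank, still on the logit scale. A sketch:

```r
# Posterior summaries of the per-tank intercepts (grand mean + offset)
tank_logits <- coef(m2)$tank[, , "Intercept"]
head(tank_logits)  # Estimate, Est.Error, and 95% CI columns, logit scale

# Point estimates back-transformed to survival probabilities
inv_logit_scaled(tank_logits[1:6, "Estimate"])
```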

exercise

We’ll recreate another figure of McElreath’s that shows the relationship between the observed and predicted probability of survival for each tank. Here’s what your figure should have:

  1. Survival probability (y-axis) for each tank (x-axis)
  2. Empty circular points representing the model’s estimated probability for each tank
  3. Filled teal/green points showing the observed proportions in each tank
  4. A horizontal line showing the average proportion across all tanks

OPTIONAL

  1. Vertical dashed lines separating the three tank size categories
  2. Text labels for each tank size group
  3. Something to indicate the posterior distribution of survival for each tank.

solution

observed_prop = d %>% 
  mutate(prob = surv/density)

modeled_prop_m = modeled_prop %>% 
  median_qi(prob) %>% 
  mutate( tank = str_extract(name, "[0-9]{1,2}"),
          tank = as.numeric(tank) ) 

avg_prop = inv_logit_scaled(mean(post$b_Intercept))

modeled_prop %>% 
  mutate( tank = str_extract(name, "[0-9]{1,2}"),
          tank = as.numeric(tank) ) %>% 
  ggplot( aes(x=tank, y=prob) ) + 
  stat_slab(alpha=.7) +
  geom_point( data = modeled_prop_m, 
              shape = 1,
              size=2) +
  geom_point(data = observed_prop, 
              color = "#5e8485",
             size=2) +
  # lines separating small, med, and large tanks
  geom_vline(xintercept = c(16.5, 32.5), 
             linetype = "dashed",
             linewidth = 1/4, color = "grey25") +
  geom_hline(yintercept = avg_prop, linewidth = .5, color = "#5e8485") +
  annotate(geom = "text", 
           x = c(8, 16 + 8, 32 + 8), y = 0, 
           label = c("small tanks", "medium tanks", "large tanks")) +
  annotate(geom = "text", 
           x = 46, y = avg_prop + .03, 
           label = "average proportion",
           color = "#5e8485") +
  labs(
    x="tank",
    y=NULL,
    title="probability of survival"
  )